# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
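"""Send/Recv ops.

Defines `Channel`, which transfers a tensor between two devices in a graph.
Non-TPU channels use the internal Send/Recv ops via `tf.raw_ops.Send` and
`tf.raw_ops.Recv`; TPU channels use `xla.send`/`xla.recv` between two
different TPU cores.

Note: Send/Recv are internal ops; python_op_gen_main.cc's PrintAllPythonOps
would need to be updated to export them as regular python op wrappers.
"""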

from lingvo import compat as tf

# pylint: disable=g-direct-tensorflow-import
from tensorflow.compiler.tf2xla.python import xla
# pylint: enable=g-direct-tensorflow-import


def _TpuCore(device):
  """Returns the TPU core represented by <device>, or -1 if not TPU."""
  prefix = "device:TPU_REPLICATED_CORE:"
  if prefix in device:
    return int(device[len(prefix):])
  return -1


class Channel:
  """A communication channel that transfers a tensor between two devices.

  Non-TPU endpoints use tf.raw_ops.Send/Recv. TPU endpoints use
  xla.send/xla.recv, must be on two different TPU cores, and require a fully
  defined shape. Send() may be called at most once per channel.
  """

  def __init__(self, dtype, shape, send_device, recv_device, name=None):
    """Constructs a channel.

    Args:
      dtype: The dtype of tensors sent through the channel.
      shape: The shape of tensors sent through the channel. Must not be None,
        and must be fully defined for TPUs.
      send_device: A fully-specified tensorflow device.
      recv_device: A fully-specified tensorflow device.
      name: A name for the channel (optional). Defaults to "channel" and is
        made unique within the current default graph.
    """
    graph = tf.get_default_graph()
    assert graph, "A channel is scoped within a tf.Graph"
    self._dtype = dtype
    self._send_device = send_device
    self._recv_device = recv_device
    # This unique name is also the tensor_name shared by the matching
    # Send and Recv ops below.
    self._name = graph.unique_name(name or "channel")

    assert shape is not None
    self._shape = tf.TensorShape(shape)

    self._send_tpu_core = _TpuCore(send_device)
    self._recv_tpu_core = _TpuCore(recv_device)
    self._send_called = False
    self._recv_op = None

    # Both endpoints must agree on being TPU or non-TPU.
    send_on_host = self._send_tpu_core == -1
    recv_on_host = self._recv_tpu_core == -1
    assert send_on_host == recv_on_host, (
        "Mixing TPU and non-TPU: %s and %s" % (send_device, recv_device))
    if self._send_tpu_core < 0:
      return
    assert self._shape.is_fully_defined(), (
        "TPU channel must have fully defined shape. Name: %s, shape: %s" %
        (self._name, self._shape))
    assert self._send_tpu_core != self._recv_tpu_core, (
        "TPU send/recv must be cross-core: %s and %s" %
        (send_device, recv_device))

  def Send(self, tensor):
    """Sends <tensor> through the channel.

    May be called at most once per channel.

    Args:
      tensor: The tensor to send; its dtype must equal the channel's dtype.

    Returns:
      The result of xla.send (TPU channel) or tf.raw_ops.Send (otherwise).
    """
    assert tensor.dtype == self._dtype
    assert not self._send_called, ("Send called multiple times for %s" %
                                   self._name)
    self._send_called = True
    if self._send_tpu_core != -1:
      with tf.device(self._send_device):
        return xla.send(
            tensor, tensor_name=self._name, name="Send_" + self._name)
    return tf.raw_ops.Send(
        tensor=tensor,
        tensor_name=self._name,
        send_device=self._send_device,
        send_device_incarnation=0,
        recv_device=self._recv_device)

  def Recv(self):
    """Receives a tensor from the channel.

    Returns:
      A tensor of the channel's dtype. On the TPU path the channel's shape is
      passed to xla.recv; otherwise the received tensor's static shape is set
      to the channel's shape.
    """
    # __init__ asserts both endpoints are TPU or both are not, so the send
    # core alone selects the path.
    if self._send_tpu_core != -1:
      with tf.device(self._recv_device):
        return xla.recv(
            self._dtype,
            tensor_name=self._name,
            shape=self._shape,
            name="Recv_" + self._name)
    result = tf.raw_ops.Recv(
        tensor_type=self._dtype,
        tensor_name=self._name,
        send_device=self._send_device,
        send_device_incarnation=0,
        recv_device=self._recv_device)
    result.set_shape(self._shape)
    return result
